{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "MTL_Train_Home_S.ipynb",
      "private_outputs": true,
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/Jeremy26/hydranets_course/blob/main/MTL_Train_Home_S.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "20Fk1u9m9xfl"
      },
      "source": [
        "# Welcome to the HydraNet Home Robot Workshop 🐸🐸🐸\n",
        "\n",
        "In this workshop, you're going to learn how to train a Neural Network that does **real-time semantic segmentation and monocular depth prediction**.\n",
        "\n",
        "![](https://d3i71xaburhd42.cloudfront.net/435d4b5c30f10753d277848a17baddebd98d3c31/2-Figure1-1.png)\n",
        "\n",
        "The Model is [a Multi-Task Learning algorithm designed by Vladimir Nekrasov](https://arxiv.org/pdf/1809.04766.pdf). The entire work is based on the **DenseTorch Library**, which you can find and use [here](https://github.com/DrSleep/DenseTorch). <p>\n",
        "\n",
        "**A note:** This notebook adapts the library, with express authorization from the author, for educational purposes.\n",
        "\n",
        "## Home Robot 🤖\n",
        "* In the previous workshop of the course, you learned how to design the model shown above and run it on the KITTI Dataset using pretrained weights. The **KITTI Dataset only has 200 examples of segmentation**. Therefore, the authors used a technique called Knowledge Distillation and fine-tuned using the Cityscapes dataset.<p>\n",
        "\n",
        "* 👉 In our case, we'll use another dataset called the [NYUDv2 Dataset](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html). **It contains 1449 annotated images for depth and segmentation**, which makes our life much simpler. Since this is an indoor dataset, we'll turn this project into a Home Robot Workshop!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5Ix2A28T-ZT_"
      },
      "source": [
        "# Imports\n",
        "\n",
        "We're going to import:\n",
        "*   The **Data from our previous notebook** (trained model, cmaps, ...)\n",
        "*   The **NYUD Dataset**, along with helper files, ground truth examples, and train/test split files\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f5QS8L_xCBLo"
      },
      "source": [
        "!wget https://hydranets-data.s3.eu-west-3.amazonaws.com/hydranets-data-2.zip && unzip -q hydranets-data-2.zip && mv hydranets-data-2/* . && rm hydranets-data-2.zip && rm -rf hydranets-data-2"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "szGHb1YQSyJ_"
      },
      "source": [
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt\n",
        "from PIL import Image\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8snxsl9eACAs"
      },
      "source": [
        "# 1 — Dataset\n",
        "Let's begin by importing our data and visualizing it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FsAvSppK_VLY"
      },
      "source": [
        "## Load and Visualize the Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UjQLkkLRULyU"
      },
      "source": [
        "import glob\n",
        "\n",
        "depth = #TODO: Load the Path\n",
        "seg = #TODO: Load the Path\n",
        "images = #TODO: Load the Path"
      ],
      "execution_count": null,
      "outputs": []
    },
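    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of collecting matching file lists with `glob`, in case it helps with the TODO above. The folder names (`rgb`, `depth`, `masks`) and the temporary directory are assumptions made up for this illustration, not the layout of the downloaded archive, so check the real folders before adapting it. Sorting matters because `glob` returns files in arbitrary order, and the three lists must line up index by index.\n",
        "\n",
        "```python\n",
        "import glob\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "# Hypothetical layout, created only for this demo\n",
        "root = tempfile.mkdtemp()\n",
        "for folder in (\"rgb\", \"depth\", \"masks\"):\n",
        "    os.makedirs(os.path.join(root, folder))\n",
        "    for i in (2, 0, 1):\n",
        "        open(os.path.join(root, folder, \"{:06d}.png\".format(i)), \"w\").close()\n",
        "\n",
        "# sorted() keeps the i-th image, depth map and mask aligned\n",
        "demo_images = sorted(glob.glob(os.path.join(root, \"rgb\", \"*.png\")))\n",
        "demo_depth = sorted(glob.glob(os.path.join(root, \"depth\", \"*.png\")))\n",
        "demo_seg = sorted(glob.glob(os.path.join(root, \"masks\", \"*.png\")))\n",
        "\n",
        "print(len(demo_images), len(demo_depth), len(demo_seg))\n",
        "print([os.path.basename(p) for p in demo_images])\n",
        "```"
      ]
    },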
    {
      "cell_type": "code",
      "metadata": {
        "id": "f3kIx0CPHeJ_"
      },
      "source": [
        "print(len(images))\n",
        "print(len(depth))\n",
        "print(len(seg))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XoZp3xq8jjYI"
      },
      "source": [
        "The segmentation masks store one class index per pixel rather than a color, so we'll need a Color Map to turn them into something we can display."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qIhWdE4Hd_Ul"
      },
      "source": [
        "CMAP = np.load('cmap_nyud.npy')\n",
        "print(len(CMAP))"
      ],
      "execution_count": null,
      "outputs": []
    },
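    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how a color map is applied, here is a small sketch with made-up data: the mask holds one class index per pixel, and indexing a `(num_classes, 3)` palette with it yields an RGB image. The 3-color palette below is a toy stand-in, not the contents of `cmap_nyud.npy`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Toy palette: one RGB color per class index (hypothetical values)\n",
        "toy_cmap = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0]], dtype=np.uint8)\n",
        "\n",
        "# Toy 2x3 segmentation mask of class indices\n",
        "toy_mask = np.array([[0, 1, 2],\n",
        "                     [2, 1, 0]])\n",
        "\n",
        "# Fancy indexing replaces every index by its color: (H, W) -> (H, W, 3)\n",
        "colored = toy_cmap[toy_mask]\n",
        "print(colored.shape)\n",
        "```\n",
        "\n",
        "The same indexing pattern is what `CMAP[np.array(Image.open(...))]` relies on in the visualization cell."
      ]
    },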
    {
      "cell_type": "code",
      "metadata": {
        "id": "6VzwxQ42D97D"
      },
      "source": [
        "idx = np.random.randint(0,len(seg))\n",
        "\n",
        "f, (ax0, ax1, ax2) = plt.subplots(1,3, figsize=(20,40))\n",
        "ax0.imshow(np.array(Image.open(images[idx])))\n",
        "ax0.set_title(\"Original\")\n",
        "ax1.imshow(np.array(Image.open(depth[idx])), cmap=\"plasma\")\n",
        "ax1.set_title(\"Depth\")\n",
        "ax2.imshow(CMAP[np.array(Image.open(seg[idx]))])\n",
        "ax2.set_title(\"Segmentation\")\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7RxtjtR_eRgt"
      },
      "source": [
        "print(np.unique(np.array(Image.open(seg[idx]))))\n",
        "print(len(np.unique(np.array(Image.open(seg[idx])))))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LrrBK7QY_Y_z"
      },
      "source": [
        "## Getting the DataLoader\n",
        "\n",
        "When training a model, 2 elements are going to be very important (compared to the last workshop):\n",
        "\n",
        "*   The Dataset\n",
        "*   The Training Loop, Loss, etc\n",
        "\n",
        "We already know how to design the model that does joint depth and segmentation, so we only need to know how to train it!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SrAXPd5j63Cw"
      },
      "source": [
        "data_file = \"train_list_depth.txt\"\n",
        "\n",
        "with open(data_file, \"rb\") as f:\n",
        "    datalist = f.readlines()\n",
        "datalist = [x.decode(\"utf-8\").strip(\"\\n\").split(\"\\t\") for x in datalist]\n",
        "\n",
        "root_dir = \"nyud\"  # relative path, matching the cells below\n",
        "masks_names = (\"segm\", \"depth\")\n",
        "\n",
        "print(datalist[0])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "48VY3A057Abw"
      },
      "source": [
        "import os\n",
        "abs_paths = [os.path.join(\"nyud\", rpath) for rpath in datalist[0]]\n",
        "abs_paths"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ix5sznp37XdA"
      },
      "source": [
        "img_arr = #TODO: Load an RGB Image\n",
        "\n",
        "plt.imshow(img_arr)\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TfjiDlLr9yAs"
      },
      "source": [
        "masks_names = (\"segm\", \"depth\")\n",
        "\n",
        "for mask_name, mask_path in zip(masks_names, abs_paths[1:]):\n",
        "    #TODO: Load the Masks for Depth and Segmentation"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P4a5e6aLuHUC"
      },
      "source": [
        "from torch.utils.data import Dataset\n",
        "\n",
        "class HydranetDataset(Dataset):\n",
        "\n",
        "    def __init__(self, data_file, transform=None):\n",
        "        with open(data_file, \"rb\") as f:\n",
        "            datalist = f.readlines()\n",
        "        self.datalist = [x.decode(\"utf-8\").strip(\"\\n\").split(\"\\t\") for x in datalist]\n",
        "        self.root_dir = \"nyud\"\n",
        "        self.transform = transform\n",
        "        self.masks_names = (\"segm\", \"depth\")\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.datalist)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        abs_paths = [os.path.join(self.root_dir, rpath) for rpath in self.datalist[idx]] # Will output list of nyud/*/00000.png\n",
        "        sample = {}\n",
        "        sample[\"image\"] = #TODO: Copy/Paste your loaded code\n",
        "\n",
        "        for mask_name, mask_path in zip(self.masks_names, abs_paths[1:]):\n",
        "            #TODO: Copy/Paste your loaded code\n",
        "\n",
        "        if self.transform:\n",
        "            sample[\"names\"] = self.masks_names\n",
        "            sample = self.transform(sample)\n",
        "            # the names key can be removed by the transformation\n",
        "            if \"names\" in sample:\n",
        "                del sample[\"names\"]\n",
        "        return sample"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IzR6cnx5-6oT"
      },
      "source": [
        "### Normalization — Will be common to all images\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oUOgFLww2rxt"
      },
      "source": [
        "from utils import Normalise, RandomCrop, ToTensor, RandomMirror\n",
        "import torchvision.transforms as transforms"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NuKBfbOb-5ll"
      },
      "source": [
        "img_scale = 1.0 / 255\n",
        "depth_scale = 5000.0\n",
        "\n",
        "img_mean = np.array([0.485, 0.456, 0.406])\n",
        "img_std = np.array([0.229, 0.224, 0.225])\n",
        "\n",
        "normalise_params = [img_scale, img_mean.reshape((1, 1, 3)), img_std.reshape((1, 1, 3)), depth_scale,]\n",
        "\n",
        "transform_common = [Normalise(*normalise_params), ToTensor()]"
      ],
      "execution_count": null,
      "outputs": []
    },
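    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of what these parameters do, the snippet below applies the usual ImageNet-style normalisation by hand. It assumes `Normalise` computes `(image * scale - mean) / std` and divides the depth map by `depth_scale`; the helper in `utils` is not shown here, so treat this as an illustration rather than its actual implementation. The random image and depth values are placeholders.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "img_scale = 1.0 / 255\n",
        "depth_scale = 5000.0\n",
        "img_mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))\n",
        "img_std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))\n",
        "\n",
        "# Placeholder inputs: an 8-bit RGB image and a raw depth map\n",
        "rng = np.random.default_rng(0)\n",
        "image = rng.integers(0, 256, size=(4, 4, 3)).astype(np.float32)\n",
        "depth = rng.integers(0, 50000, size=(4, 4)).astype(np.float32)\n",
        "\n",
        "# Scale to [0, 1], then standardise each channel (broadcast over H and W)\n",
        "image_norm = (image * img_scale - img_mean) / img_std\n",
        "# Assumed behaviour: raw depth values divided by depth_scale\n",
        "depth_norm = depth / depth_scale\n",
        "\n",
        "print(image_norm.shape, depth_norm.shape)\n",
        "```"
      ]
    },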
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_7B5LPe1_et-"
      },
      "source": [
        "### Transforms"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UdwYse7G_tip"
      },
      "source": [
        "crop_size = 400\n",
        "transform_train = transforms.Compose([RandomMirror(), RandomCrop(crop_size)] + transform_common)\n",
        "transform_val = transforms.Compose(transform_common)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7NvjcDfO_tjr"
      },
      "source": [
        "### DataLoader"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0d82jskZ_vZo"
      },
      "source": [
        "train_batch_size = 4\n",
        "val_batch_size = 4\n",
        "train_file = \"train_list_depth.txt\"\n",
        "val_file = \"val_list_depth.txt\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kQogy3tk03TP"
      },
      "source": [
        "from torch.utils.data import DataLoader\n",
        "\n",
        "#TRAIN DATALOADER\n",
        "trainloader = #TODO: Call the Train DataLoader\n",
        "\n",
        "# VALIDATION DATALOADER\n",
        "valloader = #TODO: Call the Validation DataLoader"
      ],
      "execution_count": null,
      "outputs": []
    },
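    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, a self-contained sketch of wrapping a `Dataset` that returns dictionary samples in a `DataLoader`. `ToyDataset` and its random tensors are hypothetical stand-ins for `HydranetDataset`; the point is only to show the call shape (`batch_size`, `shuffle`, `drop_last`) and how the default collate function stacks each key into a batch.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "\n",
        "class ToyDataset(Dataset):\n",
        "    \"\"\"Hypothetical stand-in returning dict samples like HydranetDataset.\"\"\"\n",
        "    def __len__(self):\n",
        "        return 10\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        return {\n",
        "            \"image\": torch.rand(3, 8, 8),\n",
        "            \"segm\": torch.randint(0, 40, (8, 8)),\n",
        "            \"depth\": torch.rand(8, 8),\n",
        "        }\n",
        "\n",
        "# Shuffle and drop the last incomplete batch for training; keep order for validation\n",
        "toy_trainloader = DataLoader(ToyDataset(), batch_size=4, shuffle=True, drop_last=True)\n",
        "toy_valloader = DataLoader(ToyDataset(), batch_size=4, shuffle=False)\n",
        "\n",
        "batch = next(iter(toy_trainloader))\n",
        "print(batch[\"image\"].shape, batch[\"segm\"].shape, batch[\"depth\"].shape)\n",
        "```\n",
        "\n",
        "Whether to shuffle, drop the last batch, or use worker processes is a design choice; the values here are one reasonable option, not a requirement of the workshop."
      ]
    },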
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LuR-bk-13cX6"
      },
      "source": [
        "# 2 — Creating the HydraNet\n",
        "We now have 2 DataLoaders: one for training, and one for validation/test. <p>\n",
        "\n",
        "In the next step, we're going to define our model, following the paper [Real-Time Joint Semantic Segmentation and Depth Estimation Using Asymmetric Annotations](https://arxiv.org/pdf/1809.04766.pdf). If you haven't read it yet, now is the time.\n",
        "<p>\n",
        "\n",
        "> ![](https://d3i71xaburhd42.cloudfront.net/435d4b5c30f10753d277848a17baddebd98d3c31/2-Figure1-1.png)\n",
        "\n",
        "Our model takes an input RGB image, passes it through an encoder and a lightweight RefineNet decoder, and ends with 2 heads, one for each task.<p>\n",
        "Things to note:\n",
        "* The only **convolutions** we'll need will be 3x3 and 1x1\n",
        "* We also need a **MaxPooling 5x5**\n",
        "* **CRP-Blocks** are implemented as Skip-Connection Operations\n",
        "* **Each Head is made of a 1x1 convolution followed by a 3x3 convolution**, only the data and the loss change there\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "025K4utrBqNy"
      },
      "source": [
        "## Building the Encoder — A MobileNetv2\n",
        "![](https://iq.opengenus.org/content/images/2020/11/conv_mobilenet_v2.jpg)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TpCvdaopBGqo"
      },
      "source": [
        "def conv3x3(in_channels, out_channels, stride=1, dilation=1, groups=1, bias=False):\n",
        "    \"\"\"3x3 Convolution (depthwise when groups == in_channels):\n",
        "    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html\n",
        "    \"\"\"\n",
        "    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=bias, groups=groups)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "X26XhiFLBZD5"
      },
      "source": [
        "def conv1x1(in_channels, out_channels, stride=1, groups=1, bias=False,):\n",
        "    \"1x1 Convolution: Pointwise\"\n",
        "    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=bias, groups=groups)"
      ],
      "execution_count": null,
      "outputs": []
    },
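    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `groups` argument is what turns these helpers into depthwise convolutions. A quick sketch, using plain `nn.Conv2d` with arbitrarily chosen channel counts (32 in, 64 out), compares the parameter count of a standard 3x3 convolution with a depthwise 3x3 followed by a pointwise 1x1:\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "def n_params(module):\n",
        "    return sum(p.numel() for p in module.parameters())\n",
        "\n",
        "# Standard 3x3 convolution: every output channel sees every input channel\n",
        "standard = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)\n",
        "\n",
        "# Depthwise 3x3 (groups == in_channels), then pointwise 1x1 to mix channels\n",
        "depthwise = nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=32, bias=False)\n",
        "pointwise = nn.Conv2d(32, 64, kernel_size=1, bias=False)\n",
        "\n",
        "x = torch.rand(1, 32, 16, 16)\n",
        "assert standard(x).shape == pointwise(depthwise(x)).shape\n",
        "\n",
        "print(n_params(standard))                         # 64 * 32 * 3 * 3 = 18432\n",
        "print(n_params(depthwise) + n_params(pointwise))  # 32 * 3 * 3 + 64 * 32 = 2336\n",
        "```\n",
        "\n",
        "Both paths produce the same output shape, but the separable version uses roughly 8x fewer weights in this example."
      ]
    },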
    {
      "cell_type": "code",
      "metadata": {
        "id": "inBcLWf3jGvd"
      },
      "source": [
        "def batchnorm(num_features):\n",
        "    \"\"\"\n",
        "    https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html\n",
        "    \"\"\"\n",
        "    return nn.BatchNorm2d(num_features, affine=True, eps=1e-5, momentum=0.1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vDJDvGiQSy2L"
      },
      "source": [
        "def convbnrelu(in_channels, out_channels, kernel_size, stride=1, groups=1, act=True):\n",
        "    \"conv-batchnorm-relu\"\n",
        "    if act:\n",
        "        return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=int(kernel_size / 2.), groups=groups, bias=False),\n",
        "                             batchnorm(out_channels),\n",
        "                             nn.ReLU6(inplace=True))\n",
        "    else:\n",
        "        return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=int(kernel_size / 2.), groups=groups, bias=False),\n",
        "                             batchnorm(out_channels))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gEXj3l0OBi1w"
      },
      "source": [
        "class InvertedResidualBlock(nn.Module):\n",
        "    \"\"\"Inverted Residual Block from https://arxiv.org/abs/1801.04381\"\"\"\n",
        "    def __init__(self, in_planes, out_planes, expansion_factor, stride=1):\n",
        "        super().__init__() # Python 3\n",
        "        intermed_planes = in_planes * expansion_factor\n",
        "        self.residual = (in_planes == out_planes) and (stride == 1) # Boolean/Condition\n",
        "        self.output = nn.Sequential(convbnrelu(in_planes, intermed_planes, 1),\n",
        "                                    convbnrelu(intermed_planes, intermed_planes, 3, stride=stride, groups=intermed_planes),\n",
        "                                    convbnrelu(intermed_planes, out_planes, 1, act=False))\n",
        "\n",
        "    def forward(self, x):\n",
        "        #residual = x\n",
        "        out = self.output(x)\n",
        "        if self.residual:\n",
        "            return (out + x)#+residual\n",
        "        else:\n",
        "            return out"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Hdl3llB-85fH"
      },
      "source": [
        "class MobileNetv2(nn.Module):\n",
        "    def __init__(self, return_idx=[6]):\n",
        "        super().__init__()\n",
        "        # expansion rate, output channels, number of repeats, stride\n",
        "        self.mobilenet_config = [\n",
        "        [1, 16, 1, 1],\n",
        "        [6, 24, 2, 2],\n",
        "        [6, 32, 3, 2],\n",
        "        [6, 64, 4, 2],\n",
        "        [6, 96, 3, 1],\n",
        "        [6, 160, 3, 2],\n",
        "        [6, 320, 1, 1],\n",
        "        ]\n",
        "        self.in_channels = 32  # number of input channels\n",
        "        self.num_layers = len(self.mobilenet_config)\n",
        "        self.layer1 = convbnrelu(3, self.in_channels, kernel_size=3, stride=2)\n",
        "    \n",
        "        self.return_idx = [1, 2, 3, 4, 5, 6]\n",
        "        #self.return_idx = make_list(return_idx)\n",
        "\n",
        "        c_layer = 2\n",
        "        for t, c, n, s in self.mobilenet_config:\n",
        "            layers = []\n",
        "            for idx in range(n):\n",
        "                layers.append(InvertedResidualBlock(self.in_channels,c,expansion_factor=t,stride=s if idx == 0 else 1,))\n",
        "                self.in_channels = c\n",
        "            setattr(self, \"layer{}\".format(c_layer), nn.Sequential(*layers))\n",
        "            c_layer += 1\n",
        "\n",
        "        self._out_c = [self.mobilenet_config[idx][1] for idx in self.return_idx] # Output: [24, 32, 64, 96, 160, 320]\n",
        "\n",
        "    def forward(self, x):\n",
        "        outs = []\n",
        "        x = self.layer1(x)\n",
        "        outs.append(self.layer2(x))  # 16, x / 2\n",
        "        outs.append(self.layer3(outs[-1]))  # 24, x / 4\n",
        "        outs.append(self.layer4(outs[-1]))  # 32, x / 8\n",
        "        outs.append(self.layer5(outs[-1]))  # 64, x / 16\n",
        "        outs.append(self.layer6(outs[-1]))  # 96, x / 16\n",
        "        outs.append(self.layer7(outs[-1]))  # 160, x / 32\n",
        "        outs.append(self.layer8(outs[-1]))  # 320, x / 32\n",
        "        return [outs[idx] for idx in self.return_idx]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YRmoJbmpkR0l"
      },
      "source": [
        "encoder = MobileNetv2()\n",
        "encoder.load_state_dict(torch.load(\"mobilenetv2-e6e8dd43.pth\"))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d7bg7yzbm5Cz"
      },
      "source": [
        "#print(encoder)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ymQKYmnjCEng"
      },
      "source": [
        "## Building the Decoder - A Multi-Task Lightweight RefineNet\n",
        "Paper: https://arxiv.org/pdf/1810.03272.pdf\n",
        "![](https://drsleep.github.io/images/rf_arch.png)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8Fneuzuu_7Et"
      },
      "source": [
        "def make_list(x):\n",
        "    \"\"\"Returns the given input as a list.\"\"\"\n",
        "    if isinstance(x, list):\n",
        "        return x\n",
        "    elif isinstance(x, tuple):\n",
        "        return list(x)\n",
        "    else:\n",
        "        return [x]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "N6Ia8Cm2BhN6"
      },
      "source": [
        "class CRPBlock(nn.Module):\n",
        "    \"\"\"CRP definition\"\"\"\n",
        "    def __init__(self, in_planes, out_planes, n_stages, groups=False):\n",
        "        super().__init__()\n",
        "        for i in range(n_stages):\n",
        "            setattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'),\n",
        "                    conv1x1(in_planes if (i == 0) else out_planes,\n",
        "                            out_planes, stride=1,\n",
        "                            bias=False, groups=in_planes if groups else 1))\n",
        "        self.stride = 1\n",
        "        self.n_stages = n_stages\n",
        "        self.maxpool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)\n",
        "\n",
        "    def forward(self, x):\n",
        "        top = x\n",
        "        for i in range(self.n_stages):\n",
        "            top = self.maxpool(top)\n",
        "            top = getattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'))(top)\n",
        "            x = top + x\n",
        "        return x"
      ],
      "execution_count": null,
      "outputs": []
    },
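    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The chained residual pooling idea can be checked in isolation. The sketch below is a simplified re-creation with a single stage and an arbitrary channel count (16), not the class above: a 5x5 max-pool with stride 1 and padding 2 keeps the spatial size, so its output can be added back to the input.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "maxpool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)\n",
        "conv = nn.Conv2d(16, 16, kernel_size=1, bias=False)\n",
        "\n",
        "x = torch.rand(1, 16, 20, 30)\n",
        "top = conv(maxpool(x))   # pool, then 1x1 convolution\n",
        "out = top + x            # skip connection: shapes must match\n",
        "\n",
        "print(x.shape == out.shape)\n",
        "```"
      ]
    },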
    {
      "cell_type": "code",
      "metadata": {
        "id": "oNglBOFL_rPV"
      },
      "source": [
        "class MTLWRefineNet(nn.Module):\n",
        "    def __init__(self, input_sizes, num_classes, agg_size=256, n_crp=4):\n",
        "        super().__init__()\n",
        "\n",
        "        stem_convs = nn.ModuleList()\n",
        "        crp_blocks = nn.ModuleList()\n",
        "        adapt_convs = nn.ModuleList()\n",
        "        heads = nn.ModuleList()\n",
        "\n",
        "        # Reverse since we recover information from the end\n",
        "        input_sizes = list(reversed((input_sizes)))\n",
        "\n",
        "        # No reverse for collapse indices is needed\n",
        "        self.collapse_ind = [[0, 1], [2, 3], 4, 5]\n",
        "\n",
        "        groups = [False] * len(self.collapse_ind)\n",
        "        groups[-1] = True\n",
        "\n",
        "        for size in input_sizes:\n",
        "            stem_convs.append(conv1x1(size, agg_size, bias=False))\n",
        "\n",
        "        for group in groups:\n",
        "            crp_blocks.append(self._make_crp(agg_size, agg_size, n_crp, group))\n",
        "            adapt_convs.append(conv1x1(agg_size, agg_size, bias=False))\n",
        "\n",
        "        self.stem_convs = stem_convs\n",
        "        self.crp_blocks = crp_blocks\n",
        "        self.adapt_convs = adapt_convs[:-1]\n",
        "\n",
        "        num_classes = list(num_classes)\n",
        "        for n_out in num_classes:\n",
        "            heads.append(\n",
        "                nn.Sequential(\n",
        "                    conv1x1(agg_size, agg_size, groups=agg_size, bias=False),\n",
        "                    nn.ReLU6(inplace=False),\n",
        "                    conv3x3(agg_size, n_out, bias=True),\n",
        "                )\n",
        "            )\n",
        "\n",
        "        self.heads = heads\n",
        "        self.relu = nn.ReLU6(inplace=True)\n",
        "\n",
        "    def forward(self, xs):\n",
        "        xs = list(reversed(xs))\n",
        "        for idx, (conv, x) in enumerate(zip(self.stem_convs, xs)):\n",
        "            xs[idx] = conv(x)\n",
        "\n",
        "        # Collapse layers\n",
        "        c_xs = [sum([xs[idx] for idx in make_list(c_idx)]) for c_idx in self.collapse_ind ]\n",
        "\n",
        "        for idx, (crp, x) in enumerate(zip(self.crp_blocks, c_xs)):\n",
        "            if idx == 0:\n",
        "                y = self.relu(x)\n",
        "            else:\n",
        "                y = self.relu(x + y)\n",
        "            y = crp(y)\n",
        "            if idx < (len(c_xs) - 1):\n",
        "                y = self.adapt_convs[idx](y)\n",
        "                y = F.interpolate(\n",
        "                    y,\n",
        "                    size=c_xs[idx + 1].size()[2:],\n",
        "                    mode=\"bilinear\",\n",
        "                    align_corners=True,\n",
        "                )\n",
        "\n",
        "        outs = []\n",
        "        for head in self.heads:\n",
        "            outs.append(head(y))\n",
        "        return outs\n",
        "\n",
        "    @staticmethod\n",
        "    def _make_crp(in_planes, out_planes, stages, groups):\n",
        "        # Same as previous, but showing the use of a @staticmethod\n",
        "        layers = [CRPBlock(in_planes, out_planes, stages, groups)]\n",
        "        return nn.Sequential(*layers)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mf_5m6FGpXRR"
      },
      "source": [
        "num_classes = (40, 1)\n",
        "decoder = MTLWRefineNet(encoder._out_c, num_classes)\n",
        "#print(decoder)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-gHOn86dCbJQ"
      },
      "source": [
        "# 3 — Train the Model\n",
        "\n",
        "Now that we've defined our encoder and decoder, we are ready to train our model on the NYUDv2 Dataset.\n",
        "\n",
        "Here's what we'll need:\n",
        "\n",
        "*   Functions like **train() and valid()**\n",
        "*   **An Optimizer and a Loss Function**\n",
        "*   **Hyperparameters** such as Weight Decay, Momentum, Learning Rate, Epochs, ...\n",
        "\n",
        "Doesn't sound so bad, does it?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mh3thOQ9tIH3"
      },
      "source": [
        "## Loss Function\n",
        "\n",
        "Let's begin with the Loss and Optimization we'll need.\n",
        "\n",
        "* The **Segmentation Loss** is the **Cross Entropy Loss**, working as a per-pixel classification function over the 40 segmentation classes.\n",
        "\n",
        "* The **Depth Loss** will be the **Inverse Huber Loss**."
      ]
    },
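    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an illustration of the two losses, here is a sketch on random tensors. `nn.CrossEntropyLoss` with `ignore_index` is standard PyTorch. `berhu_loss` is a hand-written version of the inverse Huber (berHu) formula, with the threshold set to 20% of the largest absolute error, which is a common choice; it is an assumption that `InvHuberLoss` in `utils` behaves the same way, so check that file before relying on it.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "def berhu_loss(pred, target, ignore_value=0.0):\n",
        "    \"\"\"Sketch of the inverse Huber loss: L1 for small errors, scaled L2 for large ones.\"\"\"\n",
        "    valid = target != ignore_value          # skip pixels without ground truth\n",
        "    err = (pred - target)[valid].abs()\n",
        "    c = 0.2 * err.max()                     # assumed threshold rule (needs a non-zero error)\n",
        "    l2_part = (err ** 2 + c ** 2) / (2 * c)\n",
        "    return torch.where(err <= c, err, l2_part).mean()\n",
        "\n",
        "torch.manual_seed(0)\n",
        "# Segmentation: logits (N, C, H, W) against class indices (N, H, W)\n",
        "seg_logits = torch.randn(2, 40, 8, 8)\n",
        "seg_target = torch.randint(0, 40, (2, 8, 8))\n",
        "seg_target[0, 0, 0] = 255                   # an unlabeled pixel, ignored by the loss\n",
        "toy_crit_segm = nn.CrossEntropyLoss(ignore_index=255)\n",
        "\n",
        "# Depth: one value per pixel, 0 marks missing depth\n",
        "depth_pred = torch.rand(2, 8, 8) * 5\n",
        "depth_target = torch.rand(2, 8, 8) * 5\n",
        "depth_target[0, 0, 0] = 0.0\n",
        "\n",
        "print(toy_crit_segm(seg_logits, seg_target).item())\n",
        "print(berhu_loss(depth_pred, depth_target).item())\n",
        "```"
      ]
    },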
    {
      "cell_type": "code",
      "metadata": {
        "id": "NZwy2ATBtHQf"
      },
      "source": [
        "from utils import InvHuberLoss\n",
        "\n",
        "ignore_index = 255\n",
        "ignore_depth = 0\n",
        "\n",
        "crit_segm = #TODO: Define the Loss for Segmentation\n",
        "crit_depth = #TODO: Define the Loss for Depth"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Und9VjEtt73H"
      },
      "source": [
        "## Optimizer\n",
        "For the optimizer, we'll use **Stochastic Gradient Descent (SGD)**, with momentum and weight decay."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pKEiyVXGt1To"
      },
      "source": [
        "lr_encoder = 1e-2\n",
        "lr_decoder = 1e-3\n",
        "momentum_encoder = 0.9\n",
        "momentum_decoder = 0.9\n",
        "weight_decay_encoder = 1e-5\n",
        "weight_decay_decoder = 1e-5"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "31Imhyw6t1az"
      },
      "source": [
        "optims = #TODO: Create a List of 2 Optimizers: One for the encoder, one for the decoder"
      ],
      "execution_count": null,
      "outputs": []
    },
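    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One possible way to build one optimizer per sub-network with `torch.optim.SGD`, shown on small placeholder modules (`toy_encoder` and `toy_decoder` are hypothetical) so the snippet runs on its own. The hyperparameter values are copied from the hyperparameter cell above.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "toy_encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)\n",
        "toy_decoder = nn.Conv2d(8, 2, kernel_size=1)\n",
        "\n",
        "# Separate optimizers let the encoder and decoder use different settings\n",
        "toy_optims = [\n",
        "    torch.optim.SGD(toy_encoder.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-5),\n",
        "    torch.optim.SGD(toy_decoder.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-5),\n",
        "]\n",
        "\n",
        "# One training step: zero both, backpropagate once, step both\n",
        "loss = toy_decoder(toy_encoder(torch.rand(1, 3, 8, 8))).mean()\n",
        "for opt in toy_optims:\n",
        "    opt.zero_grad()\n",
        "loss.backward()\n",
        "for opt in toy_optims:\n",
        "    opt.step()\n",
        "\n",
        "print([opt.param_groups[0][\"lr\"] for opt in toy_optims])\n",
        "```"
      ]
    },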
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RuueL8XVvaFN"
      },
      "source": [
        "## Model Definition & State Loading"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "32DXd6vfvcIL"
      },
      "source": [
        "n_epochs = 1000"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0aelXUe9wtbg"
      },
      "source": [
        "from model_helpers import Saver, load_state_dict\n",
        "import operator \n",
        "import json\n",
        "import logging\n",
        "\n",
        "init_vals = (0.0, 10000.0)\n",
        "comp_fns = [operator.gt, operator.lt]\n",
        "ckpt_dir = \"./\"\n",
        "ckpt_path = \"./checkpoint.pth.tar\"\n",
        "\n",
        "saver = Saver(\n",
        "    args=locals(),\n",
        "    ckpt_dir=ckpt_dir,\n",
        "    best_val=init_vals,\n",
        "    condition=comp_fns,\n",
        "    save_several_mode=all,\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sJjsbuthxPwB"
      },
      "source": [
        "hydranet = nn.DataParallel(nn.Sequential(encoder, decoder).cuda()) # Use .cpu() if you prefer a slow death\n",
        "\n",
        "print(\"Model has {} parameters\".format(sum([p.numel() for p in hydranet.parameters()])))\n",
        "\n",
        "start_epoch, _, state_dict = saver.maybe_load(ckpt_path=ckpt_path, keys_to_load=[\"epoch\", \"best_val\", \"state_dict\"],)\n",
        "load_state_dict(hydranet, state_dict)\n",
        "\n",
        "if start_epoch is None:\n",
        "    start_epoch = 0"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ayNsDMhax-cp"
      },
      "source": [
        "print(start_epoch)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mre_kFpoxPMA"
      },
      "source": [
        "## Learning Rate Scheduler"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mEze71T0wQR9"
      },
      "source": [
        "opt_scheds = []\n",
        "for opt in optims:\n",
        "    opt_scheds.append(torch.optim.lr_scheduler.MultiStepLR(opt, np.arange(start_epoch + 1, n_epochs, 100), gamma=0.1))"
      ],
      "execution_count": null,
      "outputs": []
    },
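    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what `MultiStepLR` does with milestones like these, the sketch below runs a placeholder optimizer through a shortened schedule (30 epochs, a milestone every 10, starting at epoch 1 to mirror the `start_epoch + 1` offset with `start_epoch = 0`) and records the learning rate. The single dummy parameter is made up for the demo.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "import torch\n",
        "\n",
        "param = torch.nn.Parameter(torch.zeros(1))\n",
        "opt = torch.optim.SGD([param], lr=1e-2)\n",
        "\n",
        "# Milestones at epochs 1, 11, 21: the learning rate is multiplied by gamma at each\n",
        "milestones = np.arange(1, 30, 10).tolist()\n",
        "sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones, gamma=0.1)\n",
        "\n",
        "lrs = []\n",
        "for epoch in range(30):\n",
        "    lrs.append(opt.param_groups[0][\"lr\"])\n",
        "    opt.step()      # the optimizer steps before the scheduler\n",
        "    sched.step()\n",
        "\n",
        "# Learning rate seen at epochs 0, 1, 11 and 21\n",
        "print(lrs[0], lrs[1], lrs[11], lrs[21])\n",
        "```\n",
        "\n",
        "Note that with a first milestone at `start_epoch + 1`, the first decay already happens after a single epoch."
      ]
    },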
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NvPeyUELwe8e"
      },
      "source": [
        "## Training and Validation Loops\n",
        "\n",
        "Now, all we need to do is go through the Train and Validation DataLoaders, and train our model.\n",
        "\n",
        "It should look like this:\n",
        "```python\n",
        "for i in range(start_epoch, n_epochs):\n",
        "    for sched in opt_scheds:\n",
        "        sched.step(i)\n",
        "    hydranet.train() # Set to train mode\n",
        "    train(...) # Call the train function\n",
        "\n",
        "    if i % val_every == 0:\n",
        "        hydranet.eval() # Set to Eval Mode\n",
        "        with torch.no_grad():\n",
        "            vals = validate(...) # Call the validate function\n",
        "```\n",
        "\n",
        "In the (...), we'll send our dataloader, loss functions, optimizers, and everything we've defined before.<p>\n",
        "\n",
        "Which means **we need training and validation functions.**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XmAvw7eY0sQu"
      },
      "source": [
        "from utils import AverageMeter\n",
        "from tqdm import tqdm"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fI30C_Cqzc4G"
      },
      "source": [
        "def train(model, opts, crits, dataloader, loss_coeffs=(1.0,), grad_norm=0.0):\n",
        "    model.train()\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    loss_meter = AverageMeter()\n",
        "    pbar = tqdm(dataloader)\n",
        "\n",
        "    for sample in pbar:\n",
        "        loss = 0.0\n",
        "        input = #TODO: Get the Input\n",
        "        targets = #TODO: Get the Targets\n",
        "        \n",
        "        #FORWARD\n",
        "        outputs = #TODO: Run a Forward pass\n",
        "\n",
        "        for out, target, crit, loss_coeff in zip(outputs, targets, crits, loss_coeffs):\n",
        "            #TODO: Increment the Loss\n",
        "\n",
        "        # BACKWARD\n",
        "        #TODO: Zero Out the Gradients\n",
        "        #TODO: Call Loss.Backward\n",
        "\n",
        "        if grad_norm > 0.0:\n",
        "            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)\n",
        "        #TODO: Run one step\n",
        "\n",
        "        loss_meter.update(loss.item())\n",
        "        pbar.set_description(\n",
        "            \"Loss {:.3f} | Avg. Loss {:.3f}\".format(loss.item(), loss_meter.avg)\n",
        "        )"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qo5DK2vH0vcK"
      },
      "source": [
        "def validate(model, metrics, dataloader):\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model.eval()\n",
        "    for metric in metrics:\n",
        "        metric.reset()\n",
        "\n",
        "    pbar = tqdm(dataloader)\n",
        "\n",
        "    def get_val(metrics):\n",
        "        results = [(m.name, m.val()) for m in metrics]\n",
        "        names, vals = list(zip(*results))\n",
        "        out = [\"{} : {:4f}\".format(name, val) for name, val in results]\n",
        "        return vals, \" | \".join(out)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for sample in pbar:\n",
        "            # Get the Data\n",
        "            input = sample[\"image\"].float().to(device)\n",
        "            targets = [sample[k].to(device) for k in dataloader.dataset.masks_names]\n",
        "\n",
        "            #input, targets = get_input_and_targets(sample=sample, dataloader=dataloader, device=device)\n",
        "            targets = [target.squeeze(dim=1).cpu().numpy() for target in targets]\n",
        "\n",
        "            # Forward\n",
        "            outputs = model(input)\n",
        "            #outputs = make_list(outputs)\n",
        "\n",
        "            # Backward\n",
        "            for out, target, metric in zip(outputs, targets, metrics):\n",
        "                metric.update(\n",
        "                    F.interpolate(out, size=target.shape[1:], mode=\"bilinear\", align_corners=False)\n",
        "                    .squeeze(dim=1)\n",
        "                    .cpu()\n",
        "                    .numpy(),\n",
        "                    target,\n",
        "                )\n",
        "            pbar.set_description(get_val(metrics)[1])\n",
        "    vals, _ = get_val(metrics)\n",
        "    print(\"----\" * 5)\n",
        "    return vals"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YtPmGx3j2VKk"
      },
      "source": [
        "## Main Loop"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5FpRg81N205P"
      },
      "source": [
        "from utils import MeanIoU, RMSE"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ozf0W-0I1yVc"
      },
      "source": [
        "crop_size = 400\n",
        "batch_size = 4\n",
        "val_batch_size = 4\n",
        "val_every = 5\n",
        "loss_coeffs = (0.5, 0.5)\n",
        "\n",
        "#TODO: Define a Training Loop! (Good Luck!)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1MCe8myT4wtk"
      },
      "source": [
        "# Inference Challenge\n",
        "\n",
        "Now that your model is trained and checkpoint saved, try and **load an image from the test dataset and run your model on it**. Print the FPS.\n",
        "<p>\n",
        "\n",
        "**MEGA POINTS** — Load a video, and **implement a video pipeline** as we did on the previous workshop!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7PR-brzyAzIH"
      },
      "source": [
        "#Good Luck! If you have any good result, send it to jeremy@thinkautonomous.ai directly!"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}